{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "import numpy as np"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "def loadDataSet(filename):\n",
    "    \"\"\"\n",
    "    读取数据集\n",
    "\n",
    "    Args:\n",
    "        filename: 文件名\n",
    "    Returns:\n",
    "        dataMat: 数据样本矩阵\n",
    "    \"\"\"\n",
    "    dataMat = []\n",
    "    with open(filename, 'rb') as f:\n",
    "        for line in f:\n",
    "            eles = map(float,line.strip().split('\\t'))\n",
    "            dataMat.append(eles)\n",
    "    return dataMat"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "def distEclud(vecA, vecB):\n",
    "    \"\"\"\n",
    "    计算两向量的欧氏距离\n",
    "\n",
    "    Args:\n",
    "        vecA: 向量A\n",
    "        vecB: 向量B\n",
    "    Returns:\n",
    "        欧式距离\n",
    "    \"\"\"\n",
    "    return np.sqrt(np.sum(np.power(vecA - vecB, 2)))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "def randCent(dataSet, k):\n",
    "    \"\"\"\n",
    "    随机生成k个聚类中心\n",
    "\n",
    "    Args:\n",
    "        dataSet: 数据集\n",
    "        k: 簇数目\n",
    "    Returns:\n",
    "        centroids: 聚类中心矩阵\n",
    "    \"\"\"\n",
    "    m, _ = dataSet.shape\n",
    "    # 随机从数据集中选几个作为初始聚类中心\n",
    "    centroids = dataSet.take(np.random.choice(80,k), axis=0)\n",
    "    return centroids"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "def kMeans(dataSet, k, maxIter = 5):\n",
    "    \"\"\"\n",
    "    K-Means\n",
    "\n",
    "    Args:\n",
    "        dataSet: 数据集\n",
    "        k: 聚类数\n",
    "    Returns:\n",
    "        centroids: 聚类中心\n",
    "        clusterAssment: 点分配结果\n",
    "    \"\"\"\n",
    "    # 随机初始化聚类中心\n",
    "    centroids = randCent(dataSet, k)\n",
    "    init_centroids = centroids.copy()\n",
    "    \n",
    "    m, n = np.shape(dataSet)\n",
    "    \n",
    "    # 点分配结果： 第一列指明样本所在的簇，第二列指明该样本到聚类中心的距离\n",
    "    clusterAssment = np.mat(np.zeros((m, 2)))\n",
    "    \n",
    "    # 标识聚类中心是否仍在改变\n",
    "    clusterChanged = True\n",
    "    \n",
    "    # 直至聚类中心不再变化\n",
    "    iterCount = 0\n",
    "    while clusterChanged and iterCount < maxIter:\n",
    "        iterCount += 1\n",
    "        clusterChanged = False\n",
    "        # 分配样本到簇\n",
    "        for i in range(m):\n",
    "            # 计算第i个样本到各个聚类中心的距离\n",
    "            minIndex = 0\n",
    "            minDist = np.inf\n",
    "            for j in range(k):\n",
    "                dist = distEclud(dataSet[i, :],  centroids[j, :])\n",
    "                if(dist < minDist):\n",
    "                    minIndex = j\n",
    "                    minDist = dist\n",
    "            # 任何一个样本的类簇分配发生变化则认为变化\n",
    "            if(clusterAssment[i, 0] != minIndex):\n",
    "                clusterChanged = True\n",
    "            clusterAssment[i, :] = minIndex, minDist**2\n",
    "            \n",
    "        # 刷新聚类中心: 移动聚类中心到所在簇的均值位置\n",
    "        for cent in range(k):\n",
    "            # 通过数组过滤获得簇中的点\n",
    "            ptsInCluster = dataSet[np.nonzero(\n",
    "                clusterAssment[:, 0].A == cent)[0]]\n",
    "            if ptsInCluster.shape[0] > 0:\n",
    "                # 计算均值并移动\n",
    "                centroids[cent, :] = np.mean(ptsInCluster, axis=0)\n",
    "    return centroids, clusterAssment, init_centroids"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "import matplotlib.pyplot as plt"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "%matplotlib inline"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "dataMat = np.mat(loadDataSet('data/testSet.txt'))\n",
    "m,n = np.shape(dataMat)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "set_k = 2\n",
    "centroids, clusterAssment, init_centroids = kMeans(dataMat, set_k)\n",
    "\n",
    "clusterCount = np.shape(centroids)[0]\n",
    "clusterCount"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "# 注意，我们这里只设定了最多四个簇的样式，所以前面如果set_k设置超过了4，后面就会出现index error\n",
    "patterns = ['o', 'D', '^', 's']\n",
    "colors = ['b', 'g', 'y', 'black']"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "fig = plt.figure()\n",
    "title = 'kmeans with k={}'.format(set_k)\n",
    "ax = fig.add_subplot(111, title=title)\n",
    "for k in range(clusterCount):\n",
    "    # 绘制聚类中心\n",
    "    ax.scatter(centroids[k, 0], centroids[k, 1], color='r', marker='+', linewidth=20)\n",
    "    # 绘制初始聚类中心\n",
    "    ax.scatter(init_centroids[k,0], init_centroids[k,1], color='purple', marker='*', linewidths=10)\n",
    "    for i in range(m):\n",
    "        # 绘制属于该聚类中心的样本\n",
    "        ptsInCluster = dataMat[np.nonzero(clusterAssment[:, 0].A==k)[0]]\n",
    "        ax.scatter(ptsInCluster[:, 0].flatten().A[0], ptsInCluster[:, 1].flatten().A[0], marker=patterns[k], color=colors[k])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.4"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
